{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 测试分析提取伪量化算子"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm import relay"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试伪量化卷积"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #AA22FF\">@main</span>(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">224</span>, <span style=\"color: #008000\">224</span>), int8], <span style=\"color: #AA22FF; font-weight: bold\">%</span>w: Tensor[(<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">5</span>), int8]) {\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>dequantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x, <span style=\"color: #008000\">2</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">1</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>dequantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span>w, <span style=\"color: #008000\">0.5</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">2</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>, <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">1</span>, padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>], kernel_size<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">5</span>]);\n",
       "  qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>quantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">1</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int8&quot;</span>)\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "x = relay.var(\"x\", shape=[1, 3, 224, 224], dtype=\"int8\")\n",
    "w = relay.var(\"w\", shape=[16, 3, 5, 5], dtype=\"int8\")\n",
    "zero = relay.const(0)\n",
    "\n",
    "op = relay.op.nn.conv2d(\n",
    "    relay.qnn.op.dequantize(x, relay.const(2.0), zero),\n",
    "    relay.qnn.op.dequantize(w, relay.const(0.5), zero),\n",
    "    kernel_size=[5, 5],\n",
    ")\n",
    "op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype=\"int8\")\n",
    "\n",
    "mod = tvm.IRModule.from_expr(op)\n",
    "mod.show()\n",
    "fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)\n",
    "\n",
    "assert dict(fake_quantized_op_freqs) == {\"nn.conv2d\": 1}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试伪量化 dense"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #AA22FF\">@main</span>(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">64</span>), int8], <span style=\"color: #AA22FF; font-weight: bold\">%</span>w: Tensor[(<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">64</span>), int8]) {\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>dequantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x, <span style=\"color: #008000\">2</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">1</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>dequantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span>w, <span style=\"color: #008000\">0.5</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">2</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>dense(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>, <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">1</span>, units<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>);\n",
       "  qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>quantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">1</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int8&quot;</span>)\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "x = relay.var(\"x\", shape=[128, 64], dtype=\"int8\")\n",
    "w = relay.var(\"w\", shape=[256, 64], dtype=\"int8\")\n",
    "zero = relay.const(0)\n",
    "\n",
    "op = relay.op.nn.dense(\n",
    "    relay.qnn.op.dequantize(x, relay.const(2.0), zero),\n",
    "    relay.qnn.op.dequantize(w, relay.const(0.5), zero),\n",
    ")\n",
    "op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype=\"int8\")\n",
    "\n",
    "mod = tvm.IRModule.from_expr(op)\n",
    "mod.show()\n",
    "fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)\n",
    "\n",
    "assert dict(fake_quantized_op_freqs) == {\"nn.dense\": 1}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试伪量化多个区域"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #AA22FF\">@main</span>(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">64</span>), int8], <span style=\"color: #AA22FF; font-weight: bold\">%</span>w: Tensor[(<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">64</span>), int8], <span style=\"color: #AA22FF; font-weight: bold\">%</span>w2: Tensor[(<span style=\"color: #008000\">64</span>, <span style=\"color: #008000\">256</span>), int8]) {\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>dequantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x, <span style=\"color: #008000\">2</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">1</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>dequantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span>w, <span style=\"color: #008000\">0.5</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">2</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>dense(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>, <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">1</span>, units<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">3</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>quantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">1</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int8&quot;</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">4</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>dequantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">2</span>f, <span style=\"color: #008000\">114</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">5</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>relu(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">4</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">6</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>quantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">1</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int8&quot;</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">7</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>dequantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">6</span>, <span style=\"color: #008000\">1</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">8</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>dequantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span>w2, <span style=\"color: #008000\">0.5</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">9</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>dense(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">7</span>, <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">8</span>, units<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">10</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>quantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">9</span>, <span style=\"color: #008000\">1</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int8&quot;</span>);\n",
       "  sigmoid(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">10</span>)\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "x = relay.var(\"x\", shape=[128, 64], dtype=\"int8\")\n",
    "w = relay.var(\"w\", shape=[256, 64], dtype=\"int8\")\n",
    "zero = relay.const(0)\n",
    "\n",
    "op = relay.op.nn.dense(\n",
    "    relay.qnn.op.dequantize(x, relay.const(2.0), zero),\n",
    "    relay.qnn.op.dequantize(w, relay.const(0.5), zero),\n",
    ")\n",
    "op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype=\"int8\")\n",
    "\n",
    "op = relay.qnn.op.dequantize(op, relay.const(2.0), relay.const(114))\n",
    "op = relay.op.nn.relu(op)\n",
    "op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype=\"int8\")\n",
    "\n",
    "w2 = relay.var(\"w2\", shape=[64, 256], dtype=\"int8\")\n",
    "op = relay.op.nn.dense(\n",
    "    relay.qnn.op.dequantize(op, relay.const(1.0), zero),\n",
    "    relay.qnn.op.dequantize(w2, relay.const(0.5), zero),\n",
    ")\n",
    "op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype=\"int8\")\n",
    "\n",
    "# We expect to ignore this sigmoid op since it's just outside a fake\n",
    "# quantized region\n",
    "op = relay.op.sigmoid(op)\n",
    "\n",
    "mod = tvm.IRModule.from_expr(op)\n",
    "mod.show()\n",
    "fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)\n",
    "\n",
    "assert dict(fake_quantized_op_freqs) == {\"nn.dense\": 2, \"nn.relu\": 1}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试伪量化 maxpool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #AA22FF\">@main</span>(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">224</span>, <span style=\"color: #008000\">224</span>), int8]) {\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>dequantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x, <span style=\"color: #008000\">2</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">1</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>max_pool2d(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>, pool_size<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>]);\n",
       "  qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>quantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">2</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int8&quot;</span>)\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "x = relay.var(\"x\", shape=[1, 3, 224, 224], dtype=\"int8\")\n",
    "\n",
    "zero = relay.const(0)\n",
    "x = relay.qnn.op.dequantize(x, relay.const(2.0), zero)\n",
    "op = relay.op.nn.max_pool2d(x, [3, 3])\n",
    "op = relay.qnn.op.quantize(op, relay.const(2.0), zero)\n",
    "\n",
    "mod = tvm.IRModule.from_expr(op)\n",
    "mod.show()\n",
    "fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)\n",
    "\n",
    "assert dict(fake_quantized_op_freqs) == {\"nn.max_pool2d\": 1}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试伪量化转置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #AA22FF\">@main</span>(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">224</span>, <span style=\"color: #008000\">224</span>), int8]) {\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>dequantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x, <span style=\"color: #008000\">2</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">1</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> transpose(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>, axes<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">3</span>]);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">2</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> reshape(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">1</span>, newshape<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">3</span>, <span style=\"color: #AA22FF; font-weight: bold\">-</span><span style=\"color: #008000\">1</span>]);\n",
       "  qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>quantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int8&quot;</span>)\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "x = relay.var(\"x\", shape=[1, 3, 224, 224], dtype=\"int8\")\n",
    "\n",
    "zero = relay.const(0)\n",
    "x = relay.qnn.op.dequantize(x, relay.const(2.0), zero)\n",
    "op = relay.op.transpose(x, [1, 0, 2, 3])\n",
    "op = relay.op.reshape(op, [3, -1])\n",
    "op = relay.qnn.op.quantize(op, relay.const(2.0), zero)\n",
    "\n",
    "mod = tvm.IRModule.from_expr(op)\n",
    "mod.show()\n",
    "fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)\n",
    "\n",
    "assert dict(fake_quantized_op_freqs) == {\"transpose\": 1, \"reshape\": 1}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试伪量化 concat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #AA22FF\">@main</span>(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x0: Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">4</span>), int8], <span style=\"color: #AA22FF; font-weight: bold\">%</span>x1: Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">4</span>), int8], <span style=\"color: #AA22FF; font-weight: bold\">%</span>x2: Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">4</span>), int8], <span style=\"color: #AA22FF; font-weight: bold\">%</span>x3: Tensor[(<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">4</span>), int8]) {\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>dequantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x0, <span style=\"color: #008000\">0.5</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">1</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>dequantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x1, <span style=\"color: #008000\">1.5</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">2</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>dequantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x2, <span style=\"color: #008000\">2.5</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">3</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>dequantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x3, <span style=\"color: #008000\">3.5</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">4</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> (<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>, <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">1</span>, <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">2</span>, <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">3</span>);\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">5</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> concatenate(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">4</span>, axis<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>);\n",
       "  qnn<span style=\"color: #AA22FF; font-weight: bold\">.</span>quantize(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">3.5</span>f, <span style=\"color: #008000\">0</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int8&quot;</span>)\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "zero = relay.const(0)\n",
    "inputs = []\n",
    "for i in range(4):\n",
    "    inputs.append(\n",
    "        relay.qnn.op.dequantize(\n",
    "            relay.var(\"x%d\" % i, shape=[1, 4], dtype=\"int8\"), relay.const(i + 0.5), zero\n",
    "        )\n",
    "    )\n",
    "concat = relay.op.concatenate(inputs, axis=1)\n",
    "op = relay.qnn.op.quantize(concat, relay.const(3.5), zero)\n",
    "\n",
    "mod = tvm.IRModule.from_expr(op)\n",
    "mod.show()\n",
    "fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)\n",
    "\n",
    "assert dict(fake_quantized_op_freqs) == {\"concatenate\": 1}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
